"""
PyTorch implementation of the DADA optimizers
"""
import logging
from typing import Optional

import torch
from torch.optim import Optimizer

logger = logging.getLogger(__name__)


class DADA(Optimizer):
    """Dual-averaging optimizer whose step scale follows the distance travelled.

    For every parameter ``p`` the optimizer remembers its initial value
    ``init_point`` and a running sum ``grad_avg`` of normalized gradients, each
    weighted by the group's current ``rbar``. After ``k`` steps the parameter
    is set to::

        p = init_point - grad_avg / (2 * sqrt(2 * k))

    ``rbar`` is kept per parameter group. On the group's first step it is
    ``reps_rel * (1 + ||params||)``; afterwards it is the running maximum of
    itself and the distance of the group's parameters from their values at
    the first step (all tensors of the group treated as one long vector).
    Gradients are normalized per parameter tensor.

    Arguments:
        params: iterable of parameters or dicts defining parameter groups.
        reps_rel (float): relative scale of the initial ``rbar``; must be
            positive unless ``init_eta`` is given.
        eps (float): lower bound applied to a gradient norm before it is used
            as a divisor, so an all-zero gradient contributes nothing instead
            of producing NaN.
        init_eta (float, optional): must be positive when given.
            NOTE(review): when given, ``reps_rel`` is forced to 0 and
            ``init_eta`` is only stored in the group defaults; nothing in this
            class reads it afterwards, so ``rbar`` starts at 0. Confirm the
            intended first-step behaviour.

    Raises:
        ValueError: if ``init_eta`` is given and not positive, or if
            ``init_eta`` is omitted and ``reps_rel`` is not positive.
    """

    def __init__(self, params, reps_rel: float = 1e-6, eps: float = 1e-8,
                 init_eta: Optional[float] = None):
        if init_eta is not None:
            if init_eta <= 0:
                raise ValueError(f'Invalid value for init_eta ({init_eta})')
            logger.info('Ignoring reps_rel since will be explicitly set init_eta to be %s (first step size)',
                        init_eta)
            reps_rel = 0
        elif reps_rel <= 0.0:
            raise ValueError(f'Invalid reps_rel value ({reps_rel}). Suggested value is 1e-6 '
                             '(unless the model uses batch-normalization, in which case suggested value is 1e-4)')

        self._first_step = True

        defaults = dict(reps_rel=reps_rel, eps=eps, init_eta=init_eta)
        super().__init__(params, defaults)

        # Snapshot the starting point of every trainable parameter; the update
        # rule always steps away from this point, not from the current value.
        for group in self.param_groups:
            for param in group['params']:
                if param.requires_grad:
                    self._init_param_state(param)

    def __setstate__(self, state):
        super().__setstate__(state)

    def _init_param_state(self, param):
        """Create and return the per-parameter state (zero ``grad_avg``, current value as ``init_point``)."""
        self.state[param] = {
            'grad_avg': torch.zeros_like(param),
            'init_point': param.clone().detach(),
        }
        return self.state[param]

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so autograd must be re-enabled for
            # the closure's forward/backward pass.
            with torch.enable_grad():
                loss = closure()

        first_step = self._first_step

        for group in self.param_groups:
            # Groups added after the first step have no buffer yet; give them
            # one the first time they are seen.
            if first_step or 'init_buffer' not in group:
                init = group['init_buffer'] = [torch.clone(p).detach() for p in group['params']]
            else:
                init = group['init_buffer']

            self._update_group_state(group, init)

            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                p.copy_(state['init_point'] - state['eta'])

        self._first_step = False

        return loss

    def _update_group_state(self, group, init):
        """Advance the step counter, refresh ``rbar`` and accumulate gradients for one group.

        Arguments:
            group (dict): the parameter group being updated.
            init (list): tensors holding the group's parameter values at its
                first step, aligned with ``group['params']``.
        """
        k = group.get('step', 0) + 1
        group['step'] = k
        eta_scale = 1 / (2 * torch.sqrt(2 * torch.tensor(k)))

        # treat all layers as one long vector
        if self._first_step or 'rbar' not in group:
            group['rbar'] = group['reps_rel'] * (1 + torch.stack([p.norm() for p in group['params']]).norm())
        else:
            curr_d = torch.stack([torch.norm(p.detach() - pi) for p, pi in zip(group['params'], init)]).norm()
            group['rbar'] = torch.maximum(group['rbar'], curr_d)

        for p in group['params']:
            if p.grad is None:
                # Nothing to accumulate; step() leaves this parameter untouched.
                continue

            state = self.state[p]
            if 'grad_avg' not in state:
                # Parameter was frozen at construction time and unfrozen later.
                state = self._init_param_state(p)

            g = p.grad.detach()
            # Floor the norm at eps so a zero gradient yields a zero update
            # instead of 0/0 = NaN.
            g_norm = torch.sqrt((g ** 2).sum()).clamp(min=group['eps'])
            state['grad_avg'] += (group['rbar'] / g_norm) * g
            state['eta'] = eta_scale * state['grad_avg']

    def has_d_estimator(self):
        """Return ``True``: this optimizer keeps a distance estimate (``rbar``)."""
        return True

    def calculate_d_estimation_error(self, actual_d):
        """Return ``actual_d`` divided by the first parameter group's ``rbar``.

        ``rbar`` only exists after the first call to ``step()``.
        """
        return actual_d / self.param_groups[0]['rbar']